#include "inflate.h"
#include "system_error.h"

#include <boost/iostreams/filter/zlib.hpp>
#include <boost/iostreams/filtering_streambuf.hpp>
#include <boost/iostreams/copy.hpp>

#include <algorithm>
#include <cstddef>
#include <functional>
#include <random>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#define CATCH_CONFIG_MAIN

#include <catch2/catch.hpp>

namespace {

constexpr unsigned DEFAULT_WINDOW_BITS = 15;

// Chainable option set used by the Inflate fixture and the deflate() helper,
// e.g. `Options().windowBits(10).hasMoreInput()`.
//
// Note: on a non-const Options, a call to `hasMoreInput()` with no argument
// resolves to the setter (it sets the flag to true); the getter is only chosen
// for const objects/references.
class Options {
public:
    Options() = default;

    // Sets whether more input follows the chunk being passed; returns *this.
    Options& hasMoreInput(bool hasMore = true) {
        hasMoreInput_ = hasMore;
        return *this;
    }

    // Reports whether more input follows the chunk being passed.
    bool hasMoreInput() const { return hasMoreInput_; }

    // Sets the number of window bits; returns *this.
    Options& windowBits(unsigned count) {
        windowBits_ = count;
        return *this;
    }

    // Returns the number of window bits (DEFAULT_WINDOW_BITS unless changed).
    unsigned windowBits() const { return windowBits_; }

private:
    unsigned windowBits_ = DEFAULT_WINDOW_BITS;
    bool hasMoreInput_ = false;
};

// Sink handed to a custom Inflate::OutputFn: appends the bytes the function
// chooses to accept to a caller-owned string. `str` must be non-null and must
// outlive this object.
class Output {
public:
    explicit Output(std::string* str) : str_(str) {}

    // Appends `size` bytes starting at `data` and returns the number of bytes
    // appended (always `size`).
    int append(const char* data, size_t size) {
        str_->append(data, data + size);
        return static_cast<int>(size);
    }

private:
    std::string* str_;
};

// Test fixture wrapping an `inflate_ctx` whose output callback is
// outputCallback(). Decompressed data is collected into an internal string
// (see output()). An optional OutputFn can be installed to decide how much of
// each output chunk is accepted.
//
// The context is created with `this` as the callback's user data and is owned
// exclusively by this object. An Inflate therefore must not be copied or moved.
class Inflate {
public:
    // Custom output handler. It receives a chunk of decompressed data and an
    // Output sink, appends the part it accepts, and returns the byte count that is
    // passed back to the decompressor from the output callback.
    typedef std::function<size_t(const char*, size_t, Output*)> OutputFn;

    Inflate() :
            ctx_(nullptr) {
        init();
    }

    ~Inflate() {
        destroy();
    }

    // Non-copyable (and, because a copy constructor is declared, non-movable).
    // A copy would destroy the same context twice and leave the context's user
    // data pointing at the wrong object.
    Inflate(const Inflate&) = delete;
    Inflate& operator=(const Inflate&) = delete;

    // (Re)creates the decompressor with the given options. Any previous context,
    // collected output and installed OutputFn are discarded first.
    void init(const Options& opts = Options()) {
        destroy();
        inflate_opts inflOpts = {};
        inflOpts.window_bits = opts.windowBits();
        const int r = inflate_create(&ctx_, &inflOpts, outputCallback, this);
        REQUIRE(r == 0);
    }

    // Destroys the context (if any) and clears the collected output and OutputFn.
    void destroy() {
        if (ctx_ != nullptr) {
            inflate_destroy(ctx_);
            ctx_ = nullptr;
        }
        outputData_ = std::string();
        outputFn_ = OutputFn();
    }

    // Passes `*size` bytes of compressed input to inflate_input(). The library may
    // update `*size`; the tests read it back as the number of bytes consumed.
    int input(const char* data, size_t* size, const Options& opts = Options()) {
        REQUIRE(ctx_ != nullptr);
        return inflate_input(ctx_, data, size, opts.hasMoreInput() ? INFLATE_HAS_MORE_INPUT : 0);
    }

    // Convenience overload for callers that do not need the updated size.
    int input(const char* data, size_t size, const Options& opts = Options()) {
        return input(data, &size, opts);
    }

    // Calls inflate_reset() on the context and clears the collected output.
    // The installed OutputFn is kept.
    void reset() {
        REQUIRE(ctx_ != nullptr);
        inflate_reset(ctx_);
        outputData_ = std::string();
    }

    // All output accepted so far.
    const std::string& output() const {
        return outputData_;
    }

    // Installs a custom output handler (an empty function restores "accept all").
    void outputFn(OutputFn fn) {
        outputFn_ = std::move(fn);
    }

    inflate_ctx* instance() const {
        return ctx_;
    }

private:
    OutputFn outputFn_;
    std::string outputData_;
    inflate_ctx* ctx_;

    // Output callback registered with inflate_create(); `userData` is the owning
    // Inflate. Without an OutputFn every byte is accepted; otherwise the OutputFn
    // decides how many bytes to take.
    static int outputCallback(const char* data, size_t size, void* userData) {
        const auto self = static_cast<Inflate*>(userData);
        if (self->outputFn_) {
            Output out(&self->outputData_);
            return static_cast<int>(self->outputFn_(data, size, &out));
        }
        self->outputData_.append(data, size);
        return static_cast<int>(size);
    }
};

std::string deflate(const std::string& data, const Options& opts = Options()) {
    using namespace boost::iostreams;

    std::istringstream src(data);
    std::ostringstream dest;
    filtering_ostreambuf filter;
    zlib_params params;
    params.window_bits = opts.windowBits();
    params.noheader = true; // Do not add a zlib header
    filter.push(zlib_compressor(params));
    filter.push(dest);
    copy(src, filter);
    return dest.str();
}

// Returns this thread's pseudo-random engine, seeded once from std::random_device.
std::default_random_engine& randomGen() {
    static thread_local std::default_random_engine gen((std::random_device())());
    return gen;
}

// Returns a uniformly distributed random value in the closed range [min, max].
// Requires min <= max (a precondition of std::uniform_int_distribution).
size_t randomSize(size_t min, size_t max) {
    // Distribute over size_t itself. A narrower type such as `unsigned` would
    // silently truncate bounds that do not fit in it where size_t is 64-bit.
    std::uniform_int_distribution<size_t> dist(min, max);
    return dist(randomGen());
}

// Returns `size` random bytes, each drawn uniformly from 0..255.
std::string genRandomData(size_t size) {
    std::uniform_int_distribution<unsigned> byteDist(0, 255);
    std::string bytes;
    bytes.reserve(size);
    for (size_t remaining = size; remaining != 0; --remaining) {
        bytes.push_back(static_cast<char>(byteDist(randomGen())));
    }
    return bytes;
}

// Returns random bytes whose length is drawn uniformly from [minSize, maxSize].
std::string genRandomData(size_t minSize, size_t maxSize) {
    const size_t length = randomSize(minSize, maxSize);
    return genRandomData(length);
}

// Returns between 50000 and 1000000 random bytes.
std::string genRandomData() {
    return genRandomData(50000, 1000000);
}

// Returns exactly `size` bytes of random data that contains repetitions, so a
// deflate compressor has something to exploit.
//
// How it works: the output is assembled from random snippets of 5-80 bytes. On
// each step, one of 50 pre-generated snippets is reused with probability 50/151
// (about one in three). Otherwise a freshly generated snippet is appended. The
// result is then truncated to `size`.
//
// According to the original author, such data compresses by 25-30% once `size`
// is large enough (>= ~50K).
std::string genCompressibleData(size_t size) {
    const size_t minChunkSize = 5;
    const size_t maxChunkSize = 80;
    const size_t chunkCount = 50;

    // Pool of snippets that recur throughout the output.
    std::vector<std::string> pool;
    pool.reserve(chunkCount);
    while (pool.size() < chunkCount) {
        pool.push_back(genRandomData(minChunkSize, maxChunkSize));
    }

    // Draws below pool.size() select a pooled snippet; the remaining values of the
    // inclusive range [0, chunkCount * 3] mean "append fresh random data".
    std::uniform_int_distribution<unsigned> pick(0, chunkCount * 3);
    std::string out;
    out.reserve(size);
    while (out.size() < size) {
        const size_t idx = pick(randomGen());
        if (idx >= pool.size()) {
            out += genRandomData(minChunkSize, maxChunkSize);
            continue;
        }
        out += pool[idx];
    }
    out.resize(size);
    return out;
}

// Returns compressible data whose length is drawn uniformly from [minSize, maxSize].
std::string genCompressibleData(size_t minSize, size_t maxSize) {
    const size_t length = randomSize(minSize, maxSize);
    return genCompressibleData(length);
}

// Returns 50000 to 1000000 bytes of compressible data.
std::string genCompressibleData() {
    return genCompressibleData(50000, 1000000);
}

// Output callback for tests that never inspect the decompressed data: it ignores
// the bytes and reports all `size` of them as accepted.
int dummyOutputCallback(const char* /* data */, size_t size, void* /* userData */) {
    return static_cast<int>(size);
}

} // namespace

TEST_CASE("inflate_create()") {
    // Catch2 executes this body from the top once per SECTION, so every section
    // starts with its own null `ctx`.
    inflate_ctx* ctx = nullptr;

    SECTION("creates a decompressor instance") {
        const int r = inflate_create(&ctx, nullptr, dummyOutputCallback, nullptr);
        CHECK(r == 0);
        CHECK(ctx != nullptr);
        inflate_destroy(ctx);
    }

    SECTION("can be configured with a specific window size") {
        inflate_opts opts = {};
        opts.window_bits = INFLATE_MAX_WINDOW_BITS;
        const int r = inflate_create(&ctx, &opts, dummyOutputCallback, nullptr);
        CHECK(r == 0);
        CHECK(ctx != nullptr);
        inflate_destroy(ctx);
    }

    SECTION("fails if the number of window bits is out of range") {
        inflate_opts opts = {};
        // Too small
        opts.window_bits = 7;
        int r = inflate_create(&ctx, &opts, dummyOutputCallback, nullptr);
        CHECK(r == SYSTEM_ERROR_INVALID_ARGUMENT);
        CHECK(ctx == nullptr);
        // Too large
        opts.window_bits = 16;
        r = inflate_create(&ctx, &opts, dummyOutputCallback, nullptr);
        CHECK(r == SYSTEM_ERROR_INVALID_ARGUMENT);
        CHECK(ctx == nullptr);
    }

    SECTION("fails if the output callback is NULL") {
        const int r = inflate_create(&ctx, nullptr, nullptr, nullptr);
        CHECK(r == SYSTEM_ERROR_INVALID_ARGUMENT);
        CHECK(ctx == nullptr);
    }
}

TEST_CASE("inflate_input()") {
    // Catch2 re-runs this body for every SECTION, so each section starts with a
    // freshly initialized decompressor.
    Inflate infl;

    SECTION("can process the entire compressed data in one pass") {
        auto decomp = genCompressibleData();
        auto comp = deflate(decomp);
        size_t size = comp.size();
        int r = infl.input(comp.data(), &size);
        CHECK(r == INFLATE_DONE);
        CHECK(size == comp.size());
        CHECK(infl.output() == decomp);
    }

    SECTION("expects more input when INFLATE_HAS_MORE_INPUT is set") {
        auto decomp = genCompressibleData();
        auto comp = deflate(decomp);
        size_t size = comp.size() / 2;
        int r = infl.input(comp.data(), &size, Options().hasMoreInput());
        CHECK(r == INFLATE_NEEDS_MORE_INPUT);
        CHECK(size == comp.size() / 2);
        comp = comp.substr(comp.size() / 2);
        size = comp.size();
        // Pass `size` by pointer so that the check below actually verifies how
        // much of the second half was consumed.
        r = infl.input(comp.data(), &size);
        CHECK(r == INFLATE_DONE);
        CHECK(size == comp.size());
        CHECK(infl.output() == decomp);
    }

    SECTION("fails when compressed stream ends unexpectedly") {
        auto decomp = genCompressibleData();
        auto comp = deflate(decomp);
        int r = infl.input(comp.data(), comp.size() / 2);
        CHECK(r == SYSTEM_ERROR_BAD_DATA);
    }

    SECTION("can process input data in chunks of arbitrary size") {
        auto decomp = genCompressibleData();
        auto comp = deflate(decomp);
        int r = 0;
        size_t offs = 0;
        do {
            auto data = comp.data() + offs;
            size_t chunkSize = std::min(randomSize(50, 500), comp.size() - offs);
            size_t size = chunkSize;
            // INFLATE_HAS_MORE_INPUT is set for all but the last chunk
            bool hasMore = (offs + size < comp.size());
            r = infl.input(data, &size, Options().hasMoreInput(hasMore));
            REQUIRE(size == chunkSize);
            offs += size;
            // Stop once the input is exhausted. A decompressor that still asks for
            // more input then fails the check below instead of being fed empty
            // chunks indefinitely.
        } while (r == INFLATE_NEEDS_MORE_INPUT && offs < comp.size());
        CHECK(r == INFLATE_DONE);
        CHECK(infl.output() == decomp);
    }

    SECTION("can process input data in 1-byte chunks") {
        auto decomp = genCompressibleData(50000);
        auto comp = deflate(decomp);
        int r = 0;
        size_t offs = 0;
        do {
            auto data = comp.data() + offs;
            size_t size = 1;
            auto hasMore = (offs + size < comp.size());
            r = infl.input(data, &size, Options().hasMoreInput(hasMore));
            REQUIRE(size == 1);
            offs += size;
            // Never read past the end of `comp`, even if more input is requested.
        } while (r == INFLATE_NEEDS_MORE_INPUT && offs < comp.size());
        CHECK(r == INFLATE_DONE);
        CHECK(infl.output() == decomp);
    }

    SECTION("fails to process malformed input data") {
        auto comp = genRandomData(1024);
        // Set the block type of the first deflate block header (bits 1-2 of the
        // first byte) to the reserved value 0b11, which is always invalid. Purely
        // random bytes can occasionally decode as a valid stream, e.g. a final
        // block that ends immediately.
        comp[0] = static_cast<char>(comp[0] | 0x06);
        auto size = comp.size();
        int r = infl.input(comp.data(), &size);
        CHECK(r == SYSTEM_ERROR_BAD_DATA);
        infl.reset();
        // The first half of the data is well-formed
        auto decomp = genCompressibleData();
        comp = deflate(decomp);
        r = infl.input(comp.data(), comp.size() / 2, Options().hasMoreInput());
        CHECK(r == INFLATE_NEEDS_MORE_INPUT);
        comp = genRandomData(comp.size() / 2);
        r = infl.input(comp.data(), comp.size());
        CHECK(r == SYSTEM_ERROR_BAD_DATA);
    }

    SECTION("works as expected when the output is consumed in chunks of arbitrary size") {
        auto decomp = genCompressibleData(50000);
        auto comp = deflate(decomp);
        unsigned count = 0;
        infl.outputFn([&count](const char* data, size_t size, Output* out) {
            // Every fifth call accepts nothing; the others accept 10-100 bytes.
            if (++count % 5) {
                size = std::min(randomSize(10, 100), size);
            } else {
                size = 0;
            }
            return out->append(data, size);
        });
        int r = 0;
        size_t offs = 0;
        do {
            auto data = comp.data() + offs;
            size_t size = comp.size() - offs;
            r = infl.input(data, &size);
            offs += size;
        } while (r == INFLATE_HAS_MORE_OUTPUT);
        CHECK(r == INFLATE_DONE);
        CHECK(infl.output() == decomp);
    }

    SECTION("works as expected when the output is consumed in 1-byte chunks") {
        auto decomp = genCompressibleData(50000);
        auto comp = deflate(decomp);
        infl.outputFn([](const char* data, size_t size, Output* out) {
            // Accept at most one byte per call, and never more than offered.
            return out->append(data, std::min<size_t>(size, 1));
        });
        int r = 0;
        size_t offs = 0;
        do {
            auto data = comp.data() + offs;
            size_t size = comp.size() - offs;
            r = infl.input(data, &size);
            offs += size;
        } while (r == INFLATE_HAS_MORE_OUTPUT);
        CHECK(r == INFLATE_DONE);
        CHECK(infl.output() == decomp);
    }

    SECTION("can decompress a stream of unknown size") {
        auto decomp = genCompressibleData();
        auto comp = deflate(decomp);
        int r = 0;
        size_t offs = 0;
        size_t size = 0;
        do {
            auto data = comp.data() + offs;
            size_t chunkSize = std::min<size_t>(128, comp.size() - offs);
            size = chunkSize;
            // INFLATE_HAS_MORE_INPUT is set for all chunks
            r = infl.input(data, &size, Options().hasMoreInput());
            REQUIRE(size == chunkSize);
            offs += size;
            // The decompressor must recognize the end of the stream by itself. If it
            // still wants input after the last byte, fail below rather than spin on
            // empty chunks.
        } while (r == INFLATE_NEEDS_MORE_INPUT && offs < comp.size());
        CHECK(r == INFLATE_DONE);
        CHECK(infl.output() == decomp);
    }

    SECTION("works as expected with a smaller window size") {
        auto decomp = genCompressibleData();
        auto comp = deflate(decomp, Options().windowBits(10));
        infl.init(Options().windowBits(10));
        size_t size = comp.size();
        int r = infl.input(comp.data(), &size);
        CHECK(r == INFLATE_DONE);
        CHECK(size == comp.size());
        CHECK(infl.output() == decomp);
    }

    SECTION("stress test") {
        for (unsigned i = 0; i < 500; ++i) {
            auto decomp = genCompressibleData(10000, 100000);
            auto comp = deflate(decomp);
            int r = 0;
            size_t offs = 0;
            size_t chunkSize = randomSize(50, 500);
            do {
                auto data = comp.data() + offs;
                size_t size = std::min(chunkSize, comp.size() - offs);
                bool hasMore = (offs + size < comp.size());
                r = infl.input(data, size, Options().hasMoreInput(hasMore));
                offs += size;
                // Bounded by the input size: requesting more input after the last
                // chunk fails the REQUIRE below.
            } while (r == INFLATE_NEEDS_MORE_INPUT && offs < comp.size());
            REQUIRE(r == INFLATE_DONE);
            REQUIRE(infl.output() == decomp);
            infl.reset();
        }
    }
}
